-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[GlobalISel] Add constant matcher for APInt #151357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-globalisel Author: None (jyli0116) ChangesChanged m_SpecificICst, m_SpecificICstSplat and m_SpecificICstorSplat to match against APInt as well. Full diff: https://github.com/llvm/llvm-project/pull/151357.diff 4 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index c0d3a12cbcb41..e8d9bc03f6428 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -192,24 +192,35 @@ m_GFCstOrSplat(std::optional<FPValueAndVReg> &FPValReg) {
/// Matcher for a specific constant value.
struct SpecificConstantMatch {
- int64_t RequestedVal;
- SpecificConstantMatch(int64_t RequestedVal) : RequestedVal(RequestedVal) {}
+ APInt RequestedVal;
+ SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
- int64_t MatchedVal;
- return mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal;
+ APInt MatchedVal;
+ if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
+ if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
+ RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
+ else
+ MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
+
+ return APInt::isSameValue(MatchedVal, RequestedVal);
+ }
+ return false;
}
};
/// Matches a constant equal to \p RequestedValue.
+inline SpecificConstantMatch m_SpecificICst(APInt RequestedValue) {
+ return SpecificConstantMatch(std::move(RequestedValue));
+}
+
inline SpecificConstantMatch m_SpecificICst(int64_t RequestedValue) {
- return SpecificConstantMatch(RequestedValue);
+ return SpecificConstantMatch(APInt(64, RequestedValue, /* isSigned */ true));
}
/// Matcher for a specific constant splat.
struct SpecificConstantSplatMatch {
- int64_t RequestedVal;
- SpecificConstantSplatMatch(int64_t RequestedVal)
- : RequestedVal(RequestedVal) {}
+ APInt RequestedVal;
+ SpecificConstantSplatMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
/* AllowUndef */ false);
@@ -217,19 +228,31 @@ struct SpecificConstantSplatMatch {
};
/// Matches a constant splat of \p RequestedValue.
+inline SpecificConstantSplatMatch m_SpecificICstSplat(APInt RequestedValue) {
+ return SpecificConstantSplatMatch(std::move(RequestedValue));
+}
+
inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
- return SpecificConstantSplatMatch(RequestedValue);
+ return SpecificConstantSplatMatch(
+ APInt(64, RequestedValue, /* isSigned */ true));
}
/// Matcher for a specific constant or constant splat.
struct SpecificConstantOrSplatMatch {
- int64_t RequestedVal;
- SpecificConstantOrSplatMatch(int64_t RequestedVal)
+ APInt RequestedVal;
+ SpecificConstantOrSplatMatch(APInt RequestedVal)
: RequestedVal(RequestedVal) {}
bool match(const MachineRegisterInfo &MRI, Register Reg) {
- int64_t MatchedVal;
- if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
- return true;
+ APInt MatchedVal;
+ if (mi_match(Reg, MRI, m_ICst(MatchedVal))) {
+ if (MatchedVal.getBitWidth() > RequestedVal.getBitWidth())
+ RequestedVal = RequestedVal.sext(MatchedVal.getBitWidth());
+ else
+ MatchedVal = MatchedVal.sext(RequestedVal.getBitWidth());
+
+ if (APInt::isSameValue(MatchedVal, RequestedVal))
+ return true;
+ }
return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
/* AllowUndef */ false);
}
@@ -237,18 +260,24 @@ struct SpecificConstantOrSplatMatch {
/// Matches a \p RequestedValue constant or a constant splat of \p
/// RequestedValue.
+inline SpecificConstantOrSplatMatch
+m_SpecificICstOrSplat(APInt RequestedValue) {
+ return SpecificConstantOrSplatMatch(std::move(RequestedValue));
+}
+
inline SpecificConstantOrSplatMatch
m_SpecificICstOrSplat(int64_t RequestedValue) {
- return SpecificConstantOrSplatMatch(RequestedValue);
+ return SpecificConstantOrSplatMatch(
+ APInt(64, RequestedValue, /* isSigned */ true));
}
-///{
/// Convenience matchers for specific integer values.
-inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }
+inline SpecificConstantMatch m_ZeroInt() {
+ return SpecificConstantMatch(APInt(64, 0));
+}
inline SpecificConstantMatch m_AllOnesInt() {
- return SpecificConstantMatch(-1);
+ return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true));
}
-///}
/// Matcher for a specific register.
struct SpecificRegisterMatch {
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 66c960fe12c68..5c27605c26883 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -459,12 +459,24 @@ LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef);
+/// Return true if the specified register is defined by G_BUILD_VECTOR or
+/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
+LLVM_ABI bool isBuildVectorConstantSplat(const Register Reg,
+ const MachineRegisterInfo &MRI,
+ APInt SplatValue, bool AllowUndef);
+
/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef);
+/// Return true if the specified instruction is a G_BUILD_VECTOR or
+/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
+LLVM_ABI bool isBuildVectorConstantSplat(const MachineInstr &MI,
+ const MachineRegisterInfo &MRI,
+ APInt SplatValue, bool AllowUndef);
+
/// Return true if the specified instruction is a G_BUILD_VECTOR or
/// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
LLVM_ABI bool isBuildVectorAllZeros(const MachineInstr &MI,
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index f48bfc06c14be..8955dd0370539 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1401,6 +1401,21 @@ bool llvm::isBuildVectorConstantSplat(const Register Reg,
return false;
}
+bool llvm::isBuildVectorConstantSplat(const Register Reg,
+ const MachineRegisterInfo &MRI,
+ APInt SplatValue, bool AllowUndef) {
+ if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef)) {
+ if (SplatValAndReg->Value.getBitWidth() < SplatValue.getBitWidth())
+ return APInt::isSameValue(
+ SplatValAndReg->Value.sext(SplatValue.getBitWidth()), SplatValue);
+ return APInt::isSameValue(
+ SplatValAndReg->Value,
+ SplatValue.sext(SplatValAndReg->Value.getBitWidth()));
+ }
+
+ return false;
+}
+
bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
const MachineRegisterInfo &MRI,
int64_t SplatValue, bool AllowUndef) {
@@ -1408,6 +1423,13 @@ bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
AllowUndef);
}
+bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
+ const MachineRegisterInfo &MRI,
+ APInt SplatValue, bool AllowUndef) {
+ return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
+ AllowUndef);
+}
+
std::optional<APInt>
llvm::getIConstantSplatVal(const Register Reg, const MachineRegisterInfo &MRI) {
if (auto SplatValAndReg =
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index 25eb67e981588..1e0653b61e8f8 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -634,17 +634,25 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstant) {
auto FortyTwo = B.buildConstant(LLT::scalar(64), 42);
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(42)));
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(123)));
+ EXPECT_TRUE(
+ mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 42))));
+ EXPECT_FALSE(
+ mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICst(APInt(64, 123))));
// Test that this works inside of a more complex pattern.
LLT s64 = LLT::scalar(64);
auto MIBAdd = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(42)));
+ EXPECT_TRUE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 42))));
// Wrong constant.
EXPECT_FALSE(mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(123)));
+ EXPECT_FALSE(
+ mi_match(MIBAdd.getReg(2), *MRI, m_SpecificICst(APInt(64, 123))));
// No constant on the LHS.
EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
+ EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(APInt(64, 42))));
}
TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
@@ -664,6 +672,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));
+ EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+ m_SpecificICstSplat(APInt(64, 42))));
+ EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+ m_SpecificICstSplat(APInt(64, 43))));
+ EXPECT_FALSE(
+ mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(APInt(64, 42))));
+
MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
@@ -673,8 +688,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));
+ EXPECT_TRUE(
+ mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
+ EXPECT_FALSE(
+ mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 43))));
+ EXPECT_FALSE(
+ mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(APInt(64, 42))));
+
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
+ EXPECT_FALSE(
+ mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(APInt(64, 42))));
}
TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
@@ -695,6 +719,13 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
+ EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+ m_SpecificICstOrSplat(APInt(64, 42))));
+ EXPECT_FALSE(mi_match(FortyTwoSplat.getReg(0), *MRI,
+ m_SpecificICstOrSplat(APInt(64, 43))));
+ EXPECT_TRUE(
+ mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
+
MachineInstrBuilder NonConstantSplat =
B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
@@ -704,8 +735,17 @@ TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));
+ EXPECT_TRUE(
+ mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
+ EXPECT_FALSE(
+ mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 43))));
+ EXPECT_FALSE(
+ mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
+
MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
+ EXPECT_TRUE(
+ mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(APInt(64, 42))));
}
TEST_F(AArch64GISelMITest, MatchZeroInt) {
|
c76da1c to
66454fe
Compare
| APInt RequestedVal; | ||
| SpecificConstantMatch(APInt RequestedVal) : RequestedVal(RequestedVal) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const ref
| /// Convenience matchers for specific integer values. | ||
| inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); } | ||
| inline SpecificConstantMatch m_ZeroInt() { | ||
| return SpecificConstantMatch(APInt(64, 0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to use the isNullValue in APInt (really for all of these, should follow along with the IR pattern matcher structure)
| } | ||
| inline SpecificConstantMatch m_AllOnesInt() { | ||
| return SpecificConstantMatch(-1); | ||
| return SpecificConstantMatch(APInt(64, -1, /* isSigned */ true)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe use getAllOnes(64) too, although I am unsure about treating it like a isSigned. It might not be necessary here because we would not see a larger value?
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
Changed m_SpecificICst, m_SpecificICstSplat and m_SpecificICstorSplat to match against APInt as well.